#!/usr/bin/env python3

# 
# Run the test, compare results against the benchmark
#

#requires: all_tests
#Requires: hdf5

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect
import numpy as np
from sys import stdout, exit



# Build the test executable; the build output is redirected to make.log
print("Making I/O test")
shell_safe("make > make.log")

# Every variable that is read back and compared against the benchmark
vars = [
    'ivar', 'rvar', 'bvar',
    'f2d', 'f3d', 'fperp', 'fperp2',
    'ivar_evol', 'rvar_evol', 'bvar_evol',
    'v2d_evol_x', 'v2d_evol_y', 'v2d_evol_z',
    'fperp2_evol',
]

# Field quantities, not scalars
field_vars = [
    'f2d', 'f3d', 'fperp', 'fperp2',
    'v2d_evol_x', 'v2d_evol_y', 'v2d_evol_z',
    'fperp2_evol',
]

# Largest absolute difference from the benchmark that still passes
tol = 1e-6

# Load the reference value of each variable, keyed by variable name
print("Reading benchmark data")
bmk = {
    name: collect(name, path="data", prefix="benchmark.out", info=False)
    for name in vars
}

# Run the test executable on several processor counts and compare every
# variable it writes against the benchmark values held in bmk.
# The script exits with status 0 if all comparisons pass and 1 otherwise.
print("Running I/O test")
success = True
cmd = "./test_io_hdf5"  # Same executable for every processor count
for nproc in [1,2,4]:
  # On some machines need to delete dmp files first
  # or data isn't written correctly
  shell("rm data/BOUT.dmp.*")

  # Run test case, keeping its output in a per-nproc log file
  print("   %d processor...." % (nproc))
  s, out = launch_safe(cmd, nproc=nproc, pipe=True)
  with open("run.log."+str(nproc), "w") as f:
    f.write(out)

  # Collect output data
  for v in vars:
    stdout.write("      Checking variable "+v+" ... ")
    # No newline was written, so flush to make the progress line visible
    # before collect() runs
    stdout.flush()
    result = collect(v, path="data", info=False)

    # Compare benchmark and output
    if np.shape(bmk[v]) != np.shape(result):
      print("Fail, wrong shape")
      success = False
      continue

    diff = np.max(np.abs(bmk[v] - result))
    # Written as "not <=" rather than ">": every comparison with NaN is
    # False, so "diff > tol" would let NaN output be reported as a pass
    if not diff <= tol:
      print("Fail, maximum difference = "+str(diff))
      success = False
      continue

    if v in field_vars:
      # Check cell location
      if "cell_location" not in result.attributes:
        print("Fail: {0} has no cell_location attribute".format(v))
        success = False
        continue

      if result.attributes["cell_location"] != "CELL_CENTRE":
        print("Fail: Expecting cell_location == CELL_CENTRE, but got {0}".format(result.attributes["cell_location"]))
        success = False
        continue

    print("Pass")

if success:
  print(" => All I/O tests passed")
  exit(0)
else:
  print(" => Some failed tests")
  exit(1)
